package org.zjvis.datascience.web.socket;

import com.alibaba.fastjson.JSONObject;
import org.apache.curator.shaded.com.google.common.collect.Maps;
import org.springframework.stereotype.Component;

import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * WebSocket endpoint for login-related notifications, mapped to {@code /miserver/{userId}}.
 *
 * <p>Per-connection state ({@code session}, {@code userId}) lives in instance fields, so this
 * assumes the container creates one instance per connection (the JSR 356 default) — confirm if a
 * custom configurator is ever installed. All live instances are tracked in a shared static map
 * keyed by user id; a reconnect with the same id replaces the earlier entry.
 *
 * @date 2021-11-15
 */
@ServerEndpoint(value = "/miserver/{userId}")
@Component
public class WebSocket {

    /** Number of currently open sessions, counted per session (not per distinct user id). */
    private static final AtomicInteger onlineCount = new AtomicInteger();

    /**
     * Live connections keyed by user id. Only the most recently opened connection for a given
     * user id is addressable through this map.
     */
    private static final Map<String, WebSocket> clients = new ConcurrentHashMap<>();

    /** Session of this connection; assigned in {@link #onOpen}. */
    private Session session;

    /** User id taken from the request path; assigned in {@link #onOpen}. */
    private String userId;

    /**
     * Registers a newly opened connection and sends the user a "connected" message.
     *
     * @param userId  user id from the {@code {userId}} path segment
     * @param session session of the new connection
     * @throws IOException as declared by {@link #sendMsg}
     */
    @OnOpen
    public void onOpen(@PathParam("userId") String userId, Session session) throws IOException {
        this.userId = userId;
        this.session = session;
        addOnlineCount();
        clients.put(userId, this);

        sendMsg(userId, userId + "已连接");
        System.out.println(userId + "已连接");
    }

    /**
     * Unregisters this connection when it closes.
     *
     * <p>The map entry is removed only if it still refers to this instance, so closing a stale
     * connection does not evict a newer connection opened under the same user id. No message is
     * sent to the closing user: the session is going away.
     *
     * @throws IOException kept so the signature matches the original; not thrown here
     */
    @OnClose
    public void onClose() throws IOException {
        // remove(key, value) instead of remove(key): a reconnect may already have replaced us.
        clients.remove(userId, this);
        subOnlineCount();

        System.out.println(userId + "已退出");
    }

    /**
     * Sends a text message asynchronously to the connection registered for {@code userId}.
     *
     * <p>Does nothing if {@code userId} is {@code null}, no connection is registered for it, or
     * that connection's session is no longer open.
     *
     * @param userId  target user id; may be {@code null}
     * @param message text to send
     * @throws IOException not thrown by the asynchronous send; declared so existing callers that
     *                     catch it still compile
     */
    public static void sendMsg(String userId, String message) throws IOException {
        if (userId == null) {
            // ConcurrentHashMap rejects null keys with NullPointerException.
            return;
        }
        WebSocket socket = clients.get(userId);
        if (socket != null && socket.session.isOpen()) {
            socket.session.getAsyncRemote().sendText(message);
        }
    }

    /**
     * Handles an incoming JSON message of the form {@code {"userId": ..., "message": ...}} and
     * sends an acknowledgement to the user named by the payload's {@code userId}.
     *
     * <p>Payloads that parse to {@code null} (e.g. the literal {@code null}) are ignored. If the
     * payload has no {@code userId}, nothing is sent. Malformed JSON surfaces as a runtime
     * exception from the parser, which the container is expected to route to {@link #onError}.
     *
     * @param message raw text frame received from the client
     * @param session session the message arrived on (unused)
     * @throws IOException as declared by {@link #sendMsg}
     */
    @OnMessage
    public void onMessage(String message, Session session) throws IOException {
        JSONObject json = JSONObject.parseObject(message);
        if (json == null) {
            return;
        }
        String msg = json.getString("message");
        // The target comes from the payload, not from this connection's path parameter.
        String targetUserId = json.getString("userId");

        sendMsg(targetUserId, "收到" + targetUserId + "的消息:" + msg);
        System.out.println("收到" + targetUserId + "的消息:" + msg);
    }

    /**
     * Reports an error on this connection to standard error.
     *
     * @param session session on which the error occurred
     * @param error   the error raised
     */
    @OnError
    public void onError(Session session, Throwable error) {
        error.printStackTrace();
        System.err.println("连接错误, " + error.getMessage());
    }

    /** Returns the number of currently open sessions. */
    public static int getOnlineCount() {
        return onlineCount.get();
    }

    /** Increments the open-session counter. */
    public static void addOnlineCount() {
        onlineCount.incrementAndGet();
    }

    /** Decrements the open-session counter. */
    public static void subOnlineCount() {
        onlineCount.decrementAndGet();
    }

    /**
     * Returns the live connection registry keyed by user id.
     *
     * <p>This is the internal concurrent map itself, not a copy; changes made through it affect
     * message routing.
     */
    public static Map<String, WebSocket> getClients() {
        return clients;
    }

    /** Empty entry point; performs no work. */
    public static void main(String[] args) {
    }
}
